"""
fashion_mnist_kd.py
-------------------
Teacher–student knowledge distillation with variational quantum classifiers (VQCs)
on the Fashion-MNIST dataset.

* Teacher  : 10-qubit EfficientSU2 ansatz (depth-2)
* Student  :  6-qubit EfficientSU2 ansatz (depth-1)
* Distill  : teacher scores each sample → argmax "hard" pseudo-labels → student is trained on those labels
* Metrics  : plain classification accuracy on a held-out test set
"""

# ---------------------------------------------------------------
# 0. Imports
# ---------------------------------------------------------------
import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from qiskit_aer import AerSimulator
from qiskit.circuit.library import EfficientSU2
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC
from qiskit.utils import QuantumInstance


# ---------------------------------------------------------------
# 1. Distillation “server” class
# ---------------------------------------------------------------
class DistillationServer:
    """Teacher–student distillation pipeline for VQCs on Fashion-MNIST.

    Constructing an instance loads and preprocesses the dataset, then builds
    the teacher and the student classifier. Run the experiment by calling
    ``train_teacher()``, ``train_student()`` and ``report()`` in that order.

    Args:
        num_teacher_qubits: Width of the teacher's EfficientSU2 ansatz.
        num_student_qubits: Width of the student's EfficientSU2 ansatz.
        pca_components: Number of PCA components kept per image.
        seed: Seed handed to PCA and to the Aer simulators.

    NOTE(review): the ``VQC(...)`` calls below pass a callable feature map, a
    string optimizer, an ``AerSimulator`` as ``quantum_instance`` and a
    ``num_classes`` keyword. Confirm these against the installed
    qiskit-machine-learning version before relying on this class.
    """

    def __init__(
        self,
        num_teacher_qubits: int = 10,
        num_student_qubits: int = 6,
        pca_components: int = 10,
        seed: int = 123,
    ):
        self.seed = seed
        self.n_classes = 10                      # Fashion-MNIST
        self.num_teacher_qubits = num_teacher_qubits
        self.num_student_qubits = num_student_qubits
        self.pca_components = pca_components

        self._load_data()
        self._build_teacher()
        self._build_student()

    # 1.1 Data loading & preprocessing
    def _load_data(self) -> None:
        """Load Fashion-MNIST, reduce it with PCA and rescale to angles.

        Populates ``X_train``, ``X_test``, ``y_train`` and ``y_test``.
        """
        (X_tr, y_tr), (X_te, y_te) = fashion_mnist.load_data()

        # Flatten each 28x28 image and bring pixel values into [0, 1].
        X_tr = X_tr.reshape(-1, 28 * 28).astype(np.float32) / 255.0
        X_te = X_te.reshape(-1, 28 * 28).astype(np.float32) / 255.0

        # The projection is fitted on the training split only.
        pca = PCA(n_components=self.pca_components, random_state=self.seed)
        X_tr = pca.fit_transform(X_tr)
        X_te = pca.transform(X_te)

        X_tr, X_te = self._scale_to_angles(X_tr, X_te)

        self.X_train, self.X_test, self.y_train, self.y_test = (
            X_tr, X_te, y_tr, y_te
        )

    @staticmethod
    def _scale_to_angles(X_train, X_test):
        """Min-max scale both splits with the *training* statistics.

        The training data is mapped onto [0, π]. The test data goes through
        the very same affine map, so values outside the training range land
        outside [0, π].

        Args:
            X_train: Training feature array.
            X_test: Test feature array.

        Returns:
            Tuple ``(X_train_scaled, X_test_scaled)``.
        """
        # Capture the statistics before anything is rescaled; reading them
        # from the rescaled training array would leave the test features
        # effectively unscaled.
        lo = X_train.min()
        span = X_train.max() - lo + 1e-12        # epsilon guards constant input
        return np.pi * (X_train - lo) / span, np.pi * (X_test - lo) / span

    # 1.2 Utility: create a feature-map circuit from a sample vector
    @staticmethod
    def _make_feature_map(x, num_qubits):
        """Return a circuit applying ``RY(x[i])`` on qubit ``i``.

        Only the first ``num_qubits`` entries of ``x`` are used.
        """
        from qiskit import QuantumCircuit
        qc = QuantumCircuit(num_qubits)
        for i, val in enumerate(x[:num_qubits]):
            qc.ry(val, i)
        return qc

    # 1.3 Build teacher VQC
    def _build_teacher(self) -> None:
        """Create the depth-2 teacher classifier and store it as ``teacher``."""
        ansatz = EfficientSU2(self.num_teacher_qubits, reps=2)
        # NOTE(review): ``qnn`` is never handed to the VQC below; it is kept
        # only so that construction behaves exactly as before.
        qnn = SamplerQNN(
            circuit=ansatz,
            input_params=ansatz.parameters[: self.num_teacher_qubits],
            weight_params=ansatz.parameters[self.num_teacher_qubits :],
        )
        self.teacher = VQC(
            feature_map=lambda x: self._make_feature_map(x, self.num_teacher_qubits),
            ansatz=ansatz,
            optimizer="COBYLA",
            quantum_instance=AerSimulator(seed_simulator=self.seed),
            num_classes=self.n_classes,
        )

    # 1.4 Build student VQC (shallower / fewer qubits)
    def _build_student(self) -> None:
        """Create the depth-1 student classifier and store it as ``student``."""
        ansatz = EfficientSU2(self.num_student_qubits, reps=1)
        # NOTE(review): ``qnn`` is never handed to the VQC below; it is kept
        # only so that construction behaves exactly as before.
        qnn = SamplerQNN(
            circuit=ansatz,
            input_params=ansatz.parameters[: self.num_student_qubits],
            weight_params=ansatz.parameters[self.num_student_qubits :],
        )
        self.student = VQC(
            feature_map=lambda x: self._make_feature_map(x, self.num_student_qubits),
            ansatz=ansatz,
            optimizer="COBYLA",
            quantum_instance=AerSimulator(seed_simulator=self.seed),
            num_classes=self.n_classes,
        )

    # 1.5 Stage 1 – Train the teacher on true labels
    def train_teacher(self) -> None:
        """Fit the teacher on the training split with the true labels."""
        print("⇨ Training teacher VQC …")
        self.teacher.fit(self.X_train, self.y_train)

    # 1.6 Stage 2 – Generate pseudo-labels with the teacher
    def _pseudo_labels(self, X, temperature: float = 1.0):
        """Label every sample in ``X`` with the teacher's best-scoring class.

        Args:
            X: Iterable of feature vectors.
            temperature: Divisor applied to the logits before the softmax.
                Only the argmax is returned, so every positive value yields
                the same labels.

        Returns:
            Integer array with one label per sample.
        """
        # Teacher energy scores per class → logits → softmax
        logits = []
        for sample in X:
            # NOTE(review): relies on the private ``_neural_network``
            # attribute of the teacher — confirm it exposes
            # ``construct_circuit`` in the installed version.
            circuits = self.teacher._neural_network.construct_circuit(sample)
            energies = []
            for circ in circuits:
                # Shot-based estimate: share of all-zero bitstrings among the
                # measured counts (not an exact statevector probability).
                prob_zero = self.teacher.quantum_instance.execute(circ).get_counts().get(
                    "0" * circ.num_qubits, 0
                ) / self.teacher.quantum_instance._run_config.shots
                energies.append((1 - prob_zero) / 2)
            logits.append(-np.array(energies))  # negate so low energy ⇒ high logit
        logits = np.array(logits) / temperature
        # Subtract the row maximum before exponentiating for numerical stability.
        exp_l = np.exp(logits - logits.max(axis=1, keepdims=True))
        soft_labels = exp_l / exp_l.sum(axis=1, keepdims=True)
        hard_labels = np.argmax(soft_labels, axis=1)
        return hard_labels  # VQC API needs integer labels; soft targets require custom loss

    # 1.7 Stage 3 – Train the student on teacher labels
    def train_student(self) -> None:
        """Fit the student on the training split using teacher pseudo-labels."""
        print("⇨ Generating pseudo-labels from teacher …")
        pseudo_y = self._pseudo_labels(self.X_train)
        print("⇨ Training student VQC on pseudo-labels …")
        self.student.fit(self.X_train, pseudo_y)

    # 1.8 Evaluation helper
    @staticmethod
    def _accuracy(model, X, y):
        """Return the fraction of ``model.predict(X)`` entries equal to ``y``."""
        return (model.predict(X) == y).mean()

    def report(self) -> None:
        """Print test-set accuracy of the teacher and of the student."""
        print("\n=== Accuracy ===")
        print(f"Teacher  : {self._accuracy(self.teacher,  self.X_test, self.y_test):.3f}")
        print(f"Student† : {self._accuracy(self.student,  self.X_test, self.y_test):.3f}")
        print("\n† Student trained only on teacher-generated labels.")


# ---------------------------------------------------------------
# 2. Main entry point
# ---------------------------------------------------------------
if __name__ == "__main__":
    # Run the experiment end to end: build everything, train the teacher,
    # distil into the student, then print both test accuracies.
    pipeline = DistillationServer()
    pipeline.train_teacher()
    pipeline.train_student()
    pipeline.report()
